#coding=utf-8

from dcnn import *
import numpy
import logging

#test Stanford sentiment treebank classification task (positive/negative)
#you should run `python prepare.py` in the data/stanford directory first

def test_stanford_positive_negative(data_dir='../data/stanford'):
    """Train and evaluate a DCNN on the Stanford Sentiment Treebank
    binary (positive/negative) task, printing the test-set accuracy.

    Parameters
    ----------
    data_dir : str
        Directory holding ``total.data`` and the
        ``{train2,dev2,test2}.{data,label}`` files produced by running
        ``python prepare.py`` there first.  Defaults to the original
        hard-coded location, so existing callers are unaffected.
    """

    def _load_split(split):
        # Each split is a pair of files: one tokenized sentence per line
        # in *.data, and one integer class label per line in *.label.
        sentences = LineSentence('%s/%s.data' % (data_dir, split))
        labels = numpy.fromfile('%s/%s.label' % (data_dir, split),
                                sep='\n', dtype=numpy.int32)
        return sentences, labels

    # Unlabeled corpus used only for word-vector pre-training;
    # repeat=5 presumably re-iterates the file five times — confirm
    # against LineSentence's implementation.
    total_sentences = LineSentence('%s/total.data' % data_dir, repeat=5)

    train_sentences, train_labels = _load_split('train2')
    dev_sentences, dev_labels = _load_split('dev2')
    test_sentences, test_labels = _load_split('test2')

    # n_filters=[6,14] in the paper
    # n_filters=[4,6] in LeNet
    model = DCNNDeep(sentences=train_sentences, output_layer_size=2, wordvec_dim=48, alpha=0.01, entropy_descent_m=0.995,
                         dropout_rate_in_hiddens=0.5, dropout_rate_in_input=0.2, min_count=2, full_con_layer_size=5,
                         filter_width=[7,5], k_top=4, n_filters=[4,6], alpha_m=0.999995, min_alpha=0.00001,
                         pre_train_word_vec=True, pre_train_sentences=total_sentences, workers=1)
    model.train(train_sentences=train_sentences, train_labels=train_labels, patience=5,
                validate_freq=200, max_entropy_allowed=0.38,
                validate_sentences=dev_sentences, validate_labels=dev_labels, chunksize=50)
    # Parenthesized single-argument print is valid on both Python 2 and 3.
    print('test accuracy: %f' % model.accuracy(test_sentences, test_labels))

def test_douban_positive_negative(data_dir='../data/douban'):
    """Train and evaluate a DCNN on the Douban binary
    (positive/negative) sentiment task, printing the test-set accuracy.

    Parameters
    ----------
    data_dir : str
        Directory holding the ``{train2,dev2,test2}.{data,label}``
        files.  Defaults to the original hard-coded location, so
        existing callers are unaffected.
    """

    def _load_split(split):
        # Each split is a pair of files: one tokenized sentence per line
        # in *.data, and one integer class label per line in *.label.
        sentences = LineSentence('%s/%s.data' % (data_dir, split))
        labels = numpy.fromfile('%s/%s.label' % (data_dir, split),
                                sep='\n', dtype=numpy.int32)
        return sentences, labels

    train_sentences, train_labels = _load_split('train2')
    dev_sentences, dev_labels = _load_split('dev2')
    test_sentences, test_labels = _load_split('test2')

    # n_filters=[6,14] in the paper
    # n_filters=[4,6] in LeNet
    model = DCNNDeep(sentences=train_sentences, output_layer_size=2, wordvec_dim=36, alpha=0.01,
                     dropout_rate_in_hiddens=0.5, dropout_rate_in_input=0.2, min_count=2,
                     filter_width=[7,5], k_top=4, n_filters=[2,2], alpha_m=0.99999)
    model.train(train_sentences=train_sentences, train_labels=train_labels, patience=10,
                validate_sentences=dev_sentences, validate_labels=dev_labels, chunksize=5)
    # Parenthesized single-argument print is valid on both Python 2 and 3.
    print('test accuracy: %f' % model.accuracy(test_sentences, test_labels))

if __name__ == '__main__':
    # Configure the root logger so progress emitted by the dcnn module's
    # training loop is visible on the console.
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
    # NOTE(review): only the Stanford experiment runs by default;
    # test_douban_positive_negative() must be invoked manually.
    test_stanford_positive_negative()